/*
 *  Python Module for Mean Shift Image Segmentation (PyMeanShift)
 *  Copyright (C) 2011 by Frederic Jean
 *
 *  PyMeanShift is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published
 *  by the Free Software Foundation, either version 3 of the License,
 *  or (at your option) any later version.
 *
 *  PyMeanShift is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with PyMeanShift.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

#include <Python.h>
#include <numpy/arrayobject.h>
#include <climits>
#include <cstring>

#include "msImageProcessor.h"

extern "C" // Needed since this is compiled with a c++ compiler
{  

  // Module doc (Python-visible __doc__ of the _pymeanshift module).
  // Adjacent string literals are used instead of backslash-newline
  // continuations so that source indentation does not leak into the text.
  static char pmsDoc[] =
    "Python Extension for Mean Shift Image Segmentation (PyMeanShift)";

  // Segment image function doc (Python-visible __doc__ of segment()).
  // Describes the positional arguments parsed by segment() with "OidI|I"
  // and the tuple it returns.
  static char pmsSegmentDoc[] =
    "segment(image, spatial_radius, range_radius, min_region[, speedup_level])\n"
    "\n"
    "Segment an image with the mean shift algorithm.\n"
    "\n"
    "image          -- 2-D gray scale image or 3-D (height, width, 3) RGB color\n"
    "                  image; converted to unsigned 8-bit integers.\n"
    "spatial_radius -- int, spatial radius (must be >= 0).\n"
    "range_radius   -- float, range radius (must be >= 0).\n"
    "min_region     -- unsigned int, minimum region parameter of the segmenter.\n"
    "speedup_level  -- 0 (no speedup), 1 (medium speedup) or 2 (high speedup);\n"
    "                  defaults to high speedup.\n"
    "\n"
    "Returns a tuple (segmented_image, labels_image, number_regions).";
  
  // Segment image function (the only function provided by the extension).
  //
  // Python call (parsed with "OidI|I"):
  //   segment(image, spatial_radius, range_radius, min_region[, speedup_level])
  //     image          - any object accepted by PyArray_FROM_OTF; converted to a
  //                      contiguous, aligned uint8 array. 2-D (height, width)
  //                      arrays are gray scale, 3-D (height, width, 3) arrays
  //                      are RGB color.
  //     spatial_radius - int, must be >= 0.
  //     range_radius   - float, must be >= 0 (NaN is rejected).
  //     min_region     - unsigned int, forwarded to msImageProcessor::Segment.
  //     speedup_level  - 0 (no speedup), 1 (medium) or 2 (high); defaults to
  //                      HIGH_SPEEDUP.
  //
  // Returns a new tuple (segmented_image, label_image, number_of_regions):
  // segmented_image is a uint8 array with the input's shape and label_image a
  // (height, width) array of C ints copied from msImageProcessor::GetRegions.
  //
  // On error a Python exception is set and NULL is returned: ValueError for
  // invalid arguments/shapes, RuntimeError if no label data was produced, or
  // whatever NumPy raised during conversion/allocation. All references and
  // buffers acquired so far are released on every path via `cleanup`.
  static PyObject* segment(PyObject* self, PyObject* args)
  {
    // Everything is declared up front: C++ forbids `goto` jumping over
    // initialized declarations.
    PyObject* array = NULL;
    PyArrayObject* inputImage = NULL;
    PyArrayObject* segmentedImage = NULL;
    PyArrayObject* labelImage = NULL;
    PyObject* result = NULL;
    int sigmaS = 0;
    double sigmaR = 0.;
    unsigned int minRegion = 0;
    unsigned int speedUp = HIGH_SPEEDUP;

    msImageProcessor imageSegmenter;
    SpeedUpLevel speedUpLevel = HIGH_SPEEDUP;
    int* tmpLabels = NULL;
    float* tmpModes = NULL;
    int* tmpModePointCounts = NULL;
    int nbRegions = 0;
    int nbDimensions = 0;
    npy_intp channels = 1;
    npy_intp dimensions[3] = { 0, 0, 0 };

    (void) self; // unused: module-level function

    if (!PyArg_ParseTuple(args, "OidI|I", &array, &sigmaS, &sigmaR, &minRegion, &speedUp))
      return NULL;

    if (sigmaS < 0)
    {
      PyErr_SetString(PyExc_ValueError, "Spatial radius must be greater or equal to zero");
      return NULL;
    }

    // Written as !(x >= 0) so that NaN is rejected too (NaN < 0. is false).
    if (!(sigmaR >= 0.))
    {
      PyErr_SetString(PyExc_ValueError, "Range radius must be greater or equal to zero");
      return NULL;
    }

    if (speedUp > 2)
    {
      PyErr_SetString(PyExc_ValueError, "Speedup level must be 0 (no speedup), 1 (medium speedup), or 2 (high speedup)");
      return NULL;
    }

    // Get a contiguous, aligned ndarray with 8 unsigned bits per element
    // (uchar). This is a new reference (possibly to `array` itself).
    inputImage = (PyArrayObject*) PyArray_FROM_OTF(array, NPY_UBYTE, NPY_IN_ARRAY);
    if (inputImage == NULL)
      return NULL;

    // Only 2-D gray scale and 3-D images with exactly 3 channels are valid.
    // COLOR images are presumably read as 3 bytes per pixel (RGB), matching the
    // (height, width, 3) output created below; any other channel count would
    // be misinterpreted or read past the end of the input buffer.
    nbDimensions = PyArray_NDIM(inputImage);
    if (nbDimensions != 2 && nbDimensions != 3)
    {
      PyErr_SetString(PyExc_ValueError, "Array must be 2 dimentional (gray scale image) or 3 dimensional (RGB color image)");
      goto cleanup;
    }
    if (nbDimensions == 3 && PyArray_DIM(inputImage, 2) != 3)
    {
      PyErr_SetString(PyExc_ValueError, "3 dimensional arrays must have exactly 3 channels (RGB color image)");
      goto cleanup;
    }

    dimensions[0] = PyArray_DIM(inputImage, 0);
    dimensions[1] = PyArray_DIM(inputImage, 1);
    if (nbDimensions == 3)
    {
      channels = 3;
      dimensions[2] = 3;
    }

    // Height and width are passed to the segmenter as int. As a conservative
    // bound, also reject images whose byte count (height * width * channels)
    // would not fit in an int. The second term keeps dimensions[1] * channels
    // from overflowing before it is compared.
    if (dimensions[0] > INT_MAX || dimensions[1] > INT_MAX / channels ||
        (dimensions[0] > 0 && dimensions[1] * channels > INT_MAX / dimensions[0]))
    {
      PyErr_SetString(PyExc_ValueError, "Image is too large");
      goto cleanup;
    }

    imageSegmenter.DefineImage((unsigned char*)PyArray_DATA(inputImage),
                               (nbDimensions == 3) ? COLOR : GRAYSCALE,
                               (int)dimensions[0], (int)dimensions[1]);

    // Create zero-filled, C-contiguous output images.
    segmentedImage = (PyArrayObject*) PyArray_ZEROS(nbDimensions, dimensions, NPY_UBYTE, 0);
    if (segmentedImage == NULL)
      goto cleanup;

    labelImage = (PyArrayObject*) PyArray_ZEROS(2, dimensions, NPY_INT, 0);
    if (labelImage == NULL)
      goto cleanup;

    // Map the validated Python speedup level (0..2) onto the segmenter enum.
    switch (speedUp)
    {
      case 0:
        speedUpLevel = NO_SPEEDUP;
        break;
      case 1:
        speedUpLevel = MED_SPEEDUP;
        break;
      default: // 2
        speedUpLevel = HIGH_SPEEDUP;
        break;
    }

    // Segment image and get segmented image.
    imageSegmenter.Segment(sigmaS, sigmaR, minRegion, speedUpLevel);
    imageSegmenter.GetResults((unsigned char*)PyArray_DATA(segmentedImage));

    // Get labels image and number of regions. The returned arrays are released
    // with delete[] in `cleanup`.
    nbRegions = imageSegmenter.GetRegions(&tmpLabels, &tmpModes, &tmpModePointCounts);
    if (tmpLabels == NULL)
    {
      PyErr_SetString(PyExc_RuntimeError, "Mean shift segmentation did not produce any labels");
      goto cleanup;
    }
    // Byte count computed in size_t so it cannot overflow int arithmetic.
    memcpy(PyArray_DATA(labelImage), tmpLabels,
           (size_t)dimensions[0] * (size_t)dimensions[1] * sizeof(int));

    // "O" takes its own references; ours are dropped in `cleanup`, so nothing
    // leaks whether or not Py_BuildValue succeeds.
    result = Py_BuildValue("(OOi)", (PyObject*)segmentedImage, (PyObject*)labelImage, nbRegions);

  cleanup:
    delete [] tmpLabels;
    delete [] tmpModes;
    delete [] tmpModePointCounts;
    Py_XDECREF(labelImage);
    Py_XDECREF(segmentedImage);
    Py_XDECREF(inputImage);
    return result;
  }

  // Module method table: exposes segment() to Python under the name "segment",
  // called with positional arguments (METH_VARARGS) and documented by
  // pmsSegmentDoc. The all-NULL entry terminates the table.
  static PyMethodDef pmsMethods[] = {
    { "segment", segment, METH_VARARGS, pmsSegmentDoc },
    { NULL, NULL, 0, NULL } /* sentinel */
  };

  // Module initialization function.
  // Python 2 imports the module through init_pymeanshift(); Python 3 looks for
  // PyInit__pymeanshift() and expects the new module object (or NULL on error).
  // import_array() initializes the NumPy C-API used by segment(); on failure it
  // sets an ImportError and returns from the enclosing init function.
#if PY_MAJOR_VERSION >= 3
  static struct PyModuleDef pmsModule = {
    PyModuleDef_HEAD_INIT,
    "_pymeanshift",   // m_name
    pmsDoc,           // m_doc
    -1,               // m_size: single-phase init, no per-module state
    pmsMethods,       // m_methods
    NULL,             // m_slots / m_reload
    NULL,             // m_traverse
    NULL,             // m_clear
    NULL              // m_free
  };

  PyMODINIT_FUNC PyInit__pymeanshift(void)
  {
    import_array(); // returns NULL from this function on failure
    return PyModule_Create(&pmsModule);
  }
#else
  PyMODINIT_FUNC init_pymeanshift(void)
  {
    // Py_InitModule3 returns a borrowed reference, or NULL with an exception set.
    if (Py_InitModule3("_pymeanshift", pmsMethods, pmsDoc) == NULL)
      return;
    import_array();
  }
#endif
  
} 
